##### Functions to simulate processes used in figures for "Cost in Space: Debris and Collision Risk in the Orbital Commons"

# Generate a time series of launches and orbital stocks (satellites, debris) starting from (S, D).
#
# Arguments:
#   S, D            - initial satellite and debris stocks.
#   T               - number of periods; the returned matrix has T+1 rows.
#   t               - unused (kept so existing positional calls keep working).
#   opt             - launch policy: -1 never launch, 0 open-access rate from eqm_launch_rate(),
#                     1 planner's policy from opt_interpolant (and opt_policy_grid, see below).
#   constraint      - if > 0, launches chosen inside the main loop are capped at this value.
#   params          - parameter list; params$fe_eqm is used for the per-period value, and the whole
#                     list is passed to S_, D_, L and eqm_launch_rate.
#   d_varying       - optional vector; when non-empty, `d` is assigned d_varying[i] with `<<-`
#                     (i.e. in an enclosing environment, typically global) at the start of each period.
#   grid_list       - unused.
#   opt_interpolant - function(state, threads) returning the planner's launch rate at state = c(S, D).
#   ncores          - passed to opt_interpolant as `threads`.
#   opt_policy_grid - data frame with columns S, D, X. With opt == 1 and just_make_trajs == FALSE the
#                     first-period launch is read from the row whose S and D equal the inputs exactly.
#   constant_launch - if > -1, replaces the policy launch rate inside the main loop.
#   just_make_trajs - with opt == 1, TRUE takes the first-period launch from opt_interpolant instead.
#
# With opt == 1 the first-period launch is not subject to `constraint` or `constant_launch`.
# Interpolated launch rates that are empty or NA/NaN are treated as 0 (only the first element of a
# longer result is used); a NA/NaN satellite stock is replaced by the previous period's stock.
#
# Returns a matrix with columns time (1..T+1), launches, debris, satellites, value, risk.
OA.ts <- function(S,D,T,t,opt=0,constraint,params,d_varying=NULL,grid_list=NULL,opt_interpolant=NULL,ncores=NULL,opt_policy_grid=NULL, constant_launch=-1, just_make_trajs=FALSE) {
	rates <- rep(0,length=(T+1))
	sats <- rep(0,length=(T+1))
	debs <- rep(0,length=(T+1))
	sats[1] <- S
	debs[1] <- D
	values <- rep(0,length=(T+1))
	fe_eqm <- params$fe_eqm
	init_period <- 1

	# Reduce an interpolated launch rate to one usable number. The length checks must come before
	# the NA test, otherwise `if()` fails on empty or multi-element results.
	clean_rate <- function(x) {
		if(length(x)==0) return(0)
		x <- x[1]
		if(is.na(x)) 0 else x
	}

	# The two opt==1 seeding branches are both needed: reading the first launch from the exact policy
	# grid keeps the stable-basin calculation right, while the interpolant keeps projected paths right.
	if(opt==1 && just_make_trajs==FALSE) { #nolint
		row_num <- intersect(which(opt_policy_grid$S==S), which(opt_policy_grid$D==D))
		rates[1] <- opt_policy_grid$X[row_num]

		sats[2] <- S_(rates[1],S,D,params)
		debs[2] <- D_(rates[1],S,D,params)
		init_period <- 2
	}
	if(opt==1 && just_make_trajs==TRUE) { #nolint
		rates[1] <- clean_rate(opt_interpolant(c(sats[1],debs[1]), threads=ncores))

		sats[2] <- S_(rates[1],S,D,params)
		debs[2] <- D_(rates[1],S,D,params)
		init_period <- 2
	}

	for(i in init_period:(T-1)) {
		if(length(d_varying)>0) {d <<- d_varying[i]}

		if(opt==-1) {X <- 0} # the "never launch anything" plan
		if(opt==0) {X <- eqm_launch_rate(S=sats[i],D=debs[i],params=params)}
		if(opt==1) {X <- clean_rate(opt_interpolant(c(sats[i],debs[i]), threads=ncores))}
		if(constant_launch>-1) {X <- constant_launch}
		if(length(X)==0) {X <- 0}
		if(length(X)>1) {X <- X[1]}
		if(constraint>0) {X <- pmin(X,constraint)}
		rates[i] <- X
		values[i] <- (sats[i] - (1/fe_eqm)*X) # assuming returns have been normalized to 1

		sats[(i+1)] <- S_(X,sats[i],debs[i],params)
		debs[(i+1)] <- D_(X,sats[i],debs[i],params)
		if(is.na(sats[(i+1)])) {sats[(i+1)] <- sats[i]}
	}
	# Paths should have reached (or be oscillating around) the steady state by now, so repeat the
	# penultimate launch rate to emulate a truncated infinite-horizon path. (Using max(rates[-T])
	# instead would mimic a finite-horizon path with constrained final-period launches.)
	rates[T] <- rates[T-1]
	sats[(T+1)] <- S_(rates[T],sats[T],debs[T],params)
	debs[(T+1)] <- D_(rates[T],sats[T],debs[T],params)
	risk <- L(sats,debs,params)

	# The time index is 1-based (identical to the former seq(0:T)).
	series <- cbind(seq_len(T+1),rates,debs,sats,values,risk)
	colnames(series) <- c("time","launches","debris","satellites","value","risk")

	return(series)
}

# Helper to locate the grid row closest to a target point.
# Snaps dpt to the nearest value in dfrm$debris and spt to the nearest value in dfrm$satellites
# (each axis independently, by squared distance), then returns the indices of all rows of dfrm
# that carry both snapped values (integer(0) if no row has that combination).
getPt <- function(dpt,spt,dfrm) {
	nearest_debris <- dfrm$debris[which.min((dfrm$debris-dpt)^2)]
	nearest_sats <- dfrm$satellites[which.min((dfrm$satellites-spt)^2)]
	on_debris <- dfrm$debris==nearest_debris
	on_sats <- dfrm$satellites==nearest_sats
	which(on_debris & on_sats)
}

# Helper to generate an open-access launch path (opt = 0) from an initial condition.
#   init       - c(S0, D0): initial satellite and debris stocks.
#   T          - number of periods to simulate (OA.ts returns T+1 points).
#   constraint - launch cap forwarded to OA.ts (values <= 0 mean unconstrained).
#   params     - model parameter list forwarded to OA.ts.
#   dfrm       - data frame whose row count sets the output length; the path is padded with NA
#                or truncated to nrow(dfrm) so it can be bound alongside dfrm.
# Returns a data frame with columns lp_sats and lp_debs.
gen_launch_path <- function(init,T,constraint,params,dfrm) {
	# The horizon used to be hard-coded to 150, silently ignoring T.
	launch_path <- as.data.frame(OA.ts(init[1],init[2],T,0,0,constraint=constraint,params=params))
	n_rows <- nrow(dfrm)
	lp_sats <- launch_path$satellites
	lp_debs <- launch_path$debris
	length(lp_sats) <- n_rows
	length(lp_debs) <- n_rows
	clp <- data.frame(lp_sats=lp_sats,lp_debs=lp_debs)
	return(clp)
}


# Helper to calculate the basin of attraction for figure 10abc.
# For every row of SD_grid (columns `satellites` and `debris`), simulates params$T periods with OA.ts
# under the requested launch policy (in parallel on a fresh cluster of `ncores` workers) and checks
# whether the debris stock is still growing over the last 20 periods (params$T must exceed 20).
#   opt             - -1 never launch, 0 open access, 1 planner (uses opt_interpolant and opt_policy_grid).
#   constant_launch - forwarded to OA.ts for opt -1 and 0 (values > -1 override the policy).
# Returns SD_grid as a data.frame with an added `path_convergence` column:
#   0.25 when the debris stock is not growing at the end of the path,
#   0 when it is growing or the check is undefined (NA, e.g. NaN stocks in the path).
gen_fig10_data <- function(params, SD_grid, opt=0, opt_interpolant=NULL, ncores=NULL, opt_policy_grid=NULL, constant_launch=-1) {
	stopifnot(opt %in% c(-1, 0, 1))
	T <- params$T

	iter_time <- proc.time()[3]
	message("Beginning basin calculation...")

	input_dfrm <- data.frame(SD_grid)
	# %>%
	# 			mutate(oa_coll_rate = eqm_sat_stock(debris,params)) %>%
	# 			mutate(sat_null = sats_SS(satellites,debris,params,constant_launch=constant_launch)) %>%
	# 			mutate(deb_null = debs_SS(satellites,debris,params,constant_launch=constant_launch))

	series_cl <- makeCluster(ncores)
	# Register cleanup right away so the workers are released even if a simulation errors.
	on.exit(stopCluster(series_cl), add = TRUE)
	registerDoParallel(series_cl)
	path_convergence <- foreach(i=seq_len(nrow(input_dfrm)), .export=ls(envir=globalenv()), .inorder=TRUE, .combine=rbind) %dopar% {
		S <- input_dfrm$satellites[i]
		D <- input_dfrm$debris[i]
		# Simulate from (S, D) under the requested launch policy
		if(opt==-1 || opt==0) {
			series <- data.frame(OA.ts(S,D,T,0,opt,0,params, constant_launch=constant_launch))
		} else {
			series <- data.frame(OA.ts(S,D,T,0,opt,0,params,opt_interpolant=opt_interpolant,ncores=1,opt_policy_grid=opt_policy_grid))
		}
		# Kessler check: is the debris stock still growing over the last 20 periods? (An earlier
		# version also required a near-zero satellite stock; that test was always overridden.)
		debris_growing <- sum(diff(series$debris[(T-20):T])) > 1e-2
		# Not growing -> 0.25 (in basin); growing or undeterminable (NA) -> 0.
		if(isFALSE(debris_growing)) 0.25 else 0
	}

	input_dfrm <- cbind(input_dfrm, path_convergence)

	iter_time <- round(proc.time()[3] - iter_time,3)
	message("Finished basin calculation. Total time taken: ", iter_time, " seconds.")

	return(input_dfrm)
}

# Helper to build the figure 10abc phase-plane plot (debris vs satellites, drawn with coord_flip).
#   params          - model parameter list (used by the steady-state finders and by S_/D_).
#   fig10_dfrm      - basin data (e.g. from gen_fig10_data); uses satellites, debris, path_convergence and,
#                     depending on options, X, oa_coll_rate, FAR_boundary, sat_null, deb_null.
#   opt             - steady state used for the "Oa" label: -1 ss_finder_const, 0 ss_finder, 1 opt_ss_finder.
#   FAR             - 1 adds the dashed FAR_boundary line plus nullclines; 0 adds only the nullclines.
#   labels          - 1 adds text labels (Oa, Fb, Dn, Sn, Em); NULL or any other value adds none.
#   eqm_line        - 1 adds a line for the oa_coll_rate column.
#   make_nullclines - 1 computes S_next, D_next, sat_null and deb_null from column X via S_ and D_.
#   s_scale_factor, d_scale_factor - optional multipliers for the satellite / debris axes (NULL means 1).
#   natural_kessler - currently unused.
# Returns a ggplot object.
figure_10_gen <- function(params, fig10_dfrm, opt=0, FAR=1, labels=NULL, eqm_line=0, make_nullclines=0, s_scale_factor = NULL, d_scale_factor=NULL, natural_kessler=NULL) {
	if(opt==-1) {oass <- ss_finder_const(c(1,1), params)}
	if(opt==0) {oass <- ss_finder(c(1,1), params)}
	if(opt==1) {oass <- opt_ss_finder(c(0.1,0.1), params)}

	basin_points <- which(fig10_dfrm$path_convergence>0)

	if(make_nullclines==1) {
		fig10_dfrm <- fig10_dfrm %>%
			mutate(
				S_next = S_(X,satellites,debris,params),
				D_next = D_(X,satellites,debris,params),
				sat_null = (S_next - satellites)/10,
				deb_null = (D_next - debris)/10
			)
	}

	# Scale the axes if desired. S_next/D_next only exist when make_nullclines == 1, so only
	# columns that are present are rescaled (assigning a scaled missing column would fail).
	if(!is.null(s_scale_factor)) {
		fig10_dfrm$satellites <- fig10_dfrm$satellites * s_scale_factor
		if("S_next" %in% names(fig10_dfrm)) fig10_dfrm$S_next <- fig10_dfrm$S_next * s_scale_factor
		if("X" %in% names(fig10_dfrm)) fig10_dfrm$X <- fig10_dfrm$X * s_scale_factor
	}
	if(!is.null(d_scale_factor)) {
		fig10_dfrm$debris <- fig10_dfrm$debris * d_scale_factor
		if("D_next" %in% names(fig10_dfrm)) fig10_dfrm$D_next <- fig10_dfrm$D_next * d_scale_factor
	}
	# Factors used for axis limits and segments; NULL previously produced invalid (empty) limits.
	s_scale <- if(is.null(s_scale_factor)) 1 else s_scale_factor
	d_scale <- if(is.null(d_scale_factor)) 1 else d_scale_factor

	if(length(basin_points)>0) {
		figure_10 <- ggplot(data=fig10_dfrm, aes(x=debris)) +
			geom_point(data=fig10_dfrm[basin_points,], aes(x=debris,y=satellites), color="#B3B3B3")
	} else {
		figure_10 <- ggplot(data=fig10_dfrm, aes(x=debris))
	}

	if(eqm_line==1) {
		figure_10 <- figure_10 +
			geom_line(aes(y=oa_coll_rate), linewidth=1)
	}

	# Axis limits and the zero contours of sat_null / deb_null (the nullclines), shared by both FAR settings.
	nullcline_layers <- list(
		xlim(0,10*d_scale), ylim(0,2.5*s_scale),
		stat_contour(aes(x=debris,y=satellites,z=sat_null, group=after_stat(level)), linewidth=1, breaks=c(0), color="#313695", alpha=0.75),
		stat_contour(aes(x=debris,y=satellites,z=deb_null, group=after_stat(level)), linewidth=1, breaks=c(0), color="#A50026", alpha=0.75)
	)
	if(FAR==1) {
		figure_10 <- figure_10 +
			geom_line(aes(y=FAR_boundary), color="firebrick4", linewidth=1, linetype="dashed") +
			nullcline_layers
	}
	if(FAR==0) {
		figure_10 <- figure_10 + nullcline_layers
	}

	# NOTE(review): the scale_x/y_continuous calls below replace the scales created by xlim()/ylim()
	# above, so those limits do not take effect; kept as-is to avoid changing the rendered figures.
	figure_10 <- figure_10 +
		geom_segment(x=0,y=0, xend=30*d_scale,yend=0, linewidth=0.15) + geom_segment(x=0,y=0, xend=0,yend=30*s_scale, linewidth=0.15) +
		ylab("Satellites") + xlab("Debris") +
		theme_bw() + theme(text=element_text(family="Arial", size=15), legend.position="none", panel.grid.minor = element_blank()) +
		scale_x_continuous(labels = function(x) comma(x/1000)) +
		scale_y_continuous(labels = function(x) comma(x/1000))

	if(isTRUE(labels==1)) {
		figure_10 <- figure_10 +
			geom_text(label = "Oa", x=(oass[2]-0.5), y=(oass[1]-0.1), size=7) +
			geom_text(label = "Fb", x=3.75, y=2.1, color="firebrick4", size=7) +
			geom_text(label = "Dn",	x = 6, y = 1, color="#A50026", size=7) +
			geom_text(label = "Sn",	x = 7.35, y = 0.15, color="#313695", size=7) +
			geom_text(label = "Em",	x = 8.8, y = 0.275, color="black", size=7)
	}

	figure_10 <- figure_10 + coord_flip()

	return(figure_10)
}

# Continuation value at an (already updated) state.
# Builds a multilinear interpolant with chebpol::ipol from contval_dfrm$fleet_value on the grid in
# grid_list, and evaluates it at updated_state.
continuation_value <- function(updated_state, grid_list, contval_dfrm) {
    interpolant <- ipol(contval_dfrm$fleet_value, grid = grid_list, method = "multilinear")
    interpolant(updated_state)
}

# Planner's pre-value function: the planner's recursive objective, with the continuation value taken
# as already optimized. This leaves a two-period problem in the launch rate X. (Term from Kimball 2014.)
# Returns one_p_return(X, S) + params$discount * continuation value at the next state
# (S_(X,S,D), D_(X,S,D)).
pre_value_fn <- function(X,S,D,params,contval_dfrm,grid_list) {
	next_state <- c(S_(X,S,D,params), D_(X,S,D,params))
	flow_return <- one_p_return(X,S,params)
	flow_return + params$discount*continuation_value(next_state, grid_list, contval_dfrm)
}

### Propagates satellite and debris stocks from initial_condition = c(S0, D0) for n_periods periods
# (n_periods >= 1). The launch rate each period comes from a multilinear interpolant of launch_policy
# on grid_list. Returns a data frame with columns time (1..n_periods), X, S, D.
generate_time_series <- function(initial_condition,launch_policy,grid_list,n_periods,params) {
	stopifnot(n_periods >= 1)
	launch_policy_fn <- ipol(launch_policy, grid=grid_list, method="multilinear")

	launch_sequence <- rep(NA, length.out=n_periods)
	satellite_sequence <- rep(NA, length.out=n_periods)
	debris_sequence <- rep(NA, length.out=n_periods)

	satellite_sequence[1] <- initial_condition[1]
	debris_sequence[1] <- initial_condition[2]
	launch_sequence[1] <- launch_policy_fn(c(satellite_sequence[1], debris_sequence[1]))

	# seq_len()[-1] is empty for n_periods == 1 (2:n_periods would run backwards).
	for(i in seq_len(n_periods)[-1]) {
		satellite_sequence[i] <- S_(launch_sequence[i-1], satellite_sequence[i-1], debris_sequence[i-1],params)
		debris_sequence[i] <- D_(launch_sequence[i-1], satellite_sequence[i-1], debris_sequence[i-1],params)
		launch_sequence[i] <- launch_policy_fn(c(satellite_sequence[i], debris_sequence[i]))
	}

	output <- data.frame(time = c(1:n_periods),
						X = launch_sequence,
						S = satellite_sequence,
						D = debris_sequence)

	return(output)
}

### Value function iteration (dynamic programming) solver for the planner's problem.
# contval_dfrm holds the current guess on the grid described by grid_list (columns S, D, X, fleet_value).
# Each iteration re-optimizes X at every grid point with pre_value_fn (L-BFGS-B, X in [0, params$Xbar]).
# The grid points are processed via foreach/%dopar% on whatever backend is registered
# (solve_opt_policy registers one).
# Convergence measure `delta`:
#   - iterations 1-9 use the sup-norm change in the value function;
#   - later iterations use the sup-norm change in the policy.
# The loop ends once delta <= epsilon (1% of the initial mean launch rate). delta is forced to 0 from
# iteration 50 on, and reset above epsilon while fewer than steps_guarantee iterations have run.
# Side effect: each iteration saves a value/policy plot to VFI_plots/VFI_plot_<counter>.png.
# Returns the updated contval_dfrm.
dp_solver <- function(params,contval_dfrm,grid_list,steps_guarantee=NULL) {

	VFI_count <- 1
	# The value function may keep drifting, so the threshold is scaled by the policy (launch rate).
	epsilon <- 0.01*mean(contval_dfrm$X)
	delta <- 100*epsilon
	if(!isTRUE(epsilon > 0)) {
		warning("dp_solver: convergence threshold is not positive (mean launch rate of the guess is ", mean(contval_dfrm$X), "); the iteration loop may not run.", call.=FALSE)
	}

	# Keep the start time separate from the elapsed time, so the reported total stays correct after
	# the first iteration.
	start_time <- proc.time()[3]
	while(delta > epsilon) {
		message("Beginning value function iteration ", VFI_count, " with ", nrow(contval_dfrm), " grid points. Convergence threshold is ", epsilon, ".")

		# Snapshot this iteration's starting value and policy functions
		old_contval_mat <- matrix(c(as.numeric(contval_dfrm$fleet_value)), ncol=1)
		old_policy_mat <- matrix(c(as.numeric(contval_dfrm$X)), ncol=1)

		VFI_results <- foreach(i=1:nrow(contval_dfrm), .export=ls(globalenv()), .packages=c("chebpol","simFnsEqns"), .combine=rbind, .inorder=TRUE) %dopar% {

			output <- optim(par = contval_dfrm$X[i], fn = pre_value_fn, S = contval_dfrm$S[i], D = contval_dfrm$D[i], params = params, contval_dfrm = contval_dfrm, grid_list = grid_list, control = list(fnscale=-1), method = "L-BFGS-B", lower = 0, upper = params$Xbar)

			X_new <- output$par
			W_new <- output$value

			return(c(contval_dfrm$S[i],contval_dfrm$D[i],X_new,W_new))
		}

		# Columns of VFI_results: S, D, X, fleet_value
		new_contval_mat <- c(as.numeric(VFI_results[,4]))
		new_policy_mat <- c(as.numeric(VFI_results[,3]))

		# Sup-norm change: value function for the first 9 iterations, policy function afterwards.
		if(VFI_count<10) delta <- max(abs(new_contval_mat - old_contval_mat))
		if(VFI_count>=10) delta <- max(abs(new_policy_mat - old_policy_mat))
		if(VFI_count>=50) delta <- 0

		elapsed_time <- round(proc.time()[3] - start_time,3)
		message("Finished iteration ", VFI_count, ". Delta is ", delta, ". Total time taken: ", elapsed_time, " seconds.")

		if(!is.null(steps_guarantee)) { # nolint
			if(VFI_count<steps_guarantee){
				delta <- 100*epsilon
				message("Steps guarantee is ", steps_guarantee, ". Continuing to iterate.")
			}
		}

		VFI_count <- VFI_count + 1

		contval_dfrm$X <- VFI_results[,3]
		contval_dfrm$fleet_value <- VFI_results[,4]

		VFI_plot <- ggplot(data = contval_dfrm, aes(x = D, y = S)) +
			geom_contour_filled(aes(z = fleet_value)) +
			labs(x = "Debris", y = "Satellites", fill = "Fleet value") + ggtitle("VFI convergence") + theme_bw()
		policy_plot <- ggplot(data = contval_dfrm, aes(x = D, y = S)) +
			geom_contour_filled(aes(z = X)) +
			labs(x = "Debris", y = "Satellites", fill = "Launch rate") + ggtitle("Policy function") + theme_bw()

		ggsave(
			filename = paste0("VFI_plots/VFI_plot_",VFI_count,".png"),
			plot = VFI_plot | policy_plot,
			width = 10,
			height = 5,
			units = "in",
			dpi = 300
		)
	}

	return(contval_dfrm)
}

# Generate a finite time series from a launch plan.
#   X      - launch vector with (at least) T entries.
#   S, D   - initial satellite and debris stocks.
#   T      - number of periods (0 is allowed and returns just the initial state).
#   params - parameter list passed to S_ and D_.
# Returns a data frame with T+1 rows: launches (X followed by a final 0), sats, debs.
gen_fts <- function(X,S,D,T,params) {
	sats <- rep(0,length=(T+1))
	debs <- rep(0,length=(T+1))
	sats[1] <- S
	debs[1] <- D
	# seq_len(T) is empty for T == 0 (1:T would iterate over 1 and 0).
	for(period in seq_len(T)) {
		sats[(period+1)] <- S_(X[period], sats[period], debs[period], params)
		debs[(period+1)] <- D_(X[period], sats[period], debs[period], params)
	}
	output <- data.frame(launches=c(X,0), sats=sats, debs=debs)
	return(output)
}

# gen_fts(c(0,0,0),1,0,3,params)

# Net present value of a fleet path.
#   input  - table whose first column is launches and second column is satellites (e.g. gen_fts output).
#   params - parameter list; params$discount is the per-period discount factor, and params is also
#            passed to one_p_return.
# Returns sum over periods t = 1..nrow(input) of discount^(t-1) * one_p_return(X_t, S_t); 0 for empty input.
fleet_npv <- function(input, params) {
	n_periods <- nrow(input)
	X <- input[,1]
	S <- input[,2]
	discount <- params$discount
	PV_series <- rep(0,length.out=n_periods)

	# seq_len() avoids iterating over 1 and 0 when the input is empty.
	for(period in seq_len(n_periods)) {
		PV_series[period] <- (discount^(period-1))*one_p_return(X[period], S[period], params)
	}
	NPV <- sum(PV_series)
	return(NPV)
}

# Finite-horizon plan objective for guess_builder: simulates the launch plan X from (S, D) for
# T periods with gen_fts and returns the fleet NPV of that path (fleet_npv).
objective_fts <- function(X,S,D,T,params) {
	fleet_npv(gen_fts(X,S,D,T,params), params)
}

# Finite-horizon plan objective for guess_builder, using the *_cpp variants: simulates the launch
# plan X from (S, D) for T periods with gen_fts_cpp and returns fleet_npv_cpp of that path.
# This is the objective guess_builder currently optimizes.
objective_fts_cpp <- function(X,S,D,T,params) {
	fleet_npv_cpp(gen_fts_cpp(X,S,D,T,params), params)
}

# objective_fts(c(0,0,0),1,0,3,params)


# Expand (or shrink) a solved value/policy function onto a new grid by multilinear interpolation.
#   solved_values    - data frame with columns S, D, X, fleet_value on the old grid.
#   new_grid_dfrm    - two columns (S, D) listing the new grid points.
#   policy_iteration - if 1, run PFI_T rounds of policy evaluation on the new grid. Each round holds
#                      X fixed, sets fleet_value = one_p_return + params$discount * the interpolated
#                      value at the next state, and re-interpolates.
#   params           - parameter list (only needed when policy_iteration == 1).
# Returns a data frame with columns S, D, X, fleet_value on the new grid.
grid_grow <- function(solved_values, new_grid_dfrm, policy_iteration=0, PFI_T=10, params=NULL) {
    old_grid_list <- list(S=unique(solved_values$S), D=unique(solved_values$D))

    vfn_o <- ipol(solved_values$fleet_value, grid=old_grid_list, method="multilinear")
    pfn_o <- ipol(solved_values$X, grid=old_grid_list, method="multilinear")

    # The interpolants are evaluated with one query point per column, hence the transpose.
    colnames(new_grid_dfrm) <- NULL
    new_grid_mat <- t(as.matrix(new_grid_dfrm))

    vfn_n <- vfn_o(new_grid_mat)
    pfn_n <- pfn_o(new_grid_mat)

    new_values <- data.frame(t(new_grid_mat), X=pfn_n, fleet_value=vfn_n)
    colnames(new_values) <- c("S","D","X","fleet_value")

    if(policy_iteration==1) {
        message("Doing policy iteration...")
        new_values_PFI <- new_values
        new_grid_list <- list(S=unique(new_values$S), D=unique(new_values$D))
        # seq_len() so that PFI_T == 0 runs no rounds (1:0 would run two).
        for(period in seq_len(PFI_T)) {
            new_values_PFI <- new_values_PFI %>%
                mutate(s_next = S_(X,S,D,params),
                       d_next = D_(X,S,D,params),
                       contval = vfn_o(t(as.matrix(data.frame(s_next,d_next)))),
                       fleet_value = one_p_return(X,S,params)+params$discount*contval) %>%
                select(S,D,X,fleet_value)
            vfn_o <- ipol(new_values_PFI$fleet_value, grid=new_grid_list, method="multilinear")
        }
        new_values <- new_values_PFI
    }

    return(new_values)
}


# Build an initial value/policy guess for dp_solver.
# On each point of grid_dfrm (columns S, D), solves a T-period finite-horizon launch plan with
# objective_fts_cpp (L-BFGS-B, launches in [0, params$Xbar]). It keeps the first-period launch as X
# and the plan value as fleet_value. The result is interpolated onto grid_dfrm_big with grid_grow
# (optionally with policy_iteration rounds of policy evaluation).
# Side effect: saves a 2x2 summary plot to ../../images/fig_4_<label>_guess.png when fig4 is TRUE,
# or to ../images/fig_6_<label>_guess.png when fig4 is FALSE.
# `ncores` is currently unused; parallelism comes from the registered foreach backend.
# Returns the guess on the big grid (columns S, D, X, fleet_value).
guess_builder <- function(params, grid_dfrm, grid_dfrm_big, ncores=10, T=40, policy_iteration=0, PFI_T=10, label="test", fig4=FALSE) {

	init_launchpath <- rep(0, length.out=T)

	guess_time <- proc.time()[3]
	message("Building guess on small grid...")
	guess_dfrm <- foreach(i=seq_len(nrow(grid_dfrm)), .export=ls(globalenv()), .packages=c("chebpol","simFnsEqns"), .combine=rbind, .inorder=TRUE) %dopar% {
		# (objective_fts is the pure-R equivalent of objective_fts_cpp.)
		output <- optim(par = init_launchpath, fn = objective_fts_cpp, S = grid_dfrm$S[i], D = grid_dfrm$D[i], T=T, params = params, control = list(fnscale=-1), method = "L-BFGS-B", lower = 0, upper = params$Xbar)
		X_new <- output$par[1]
		W_new <- output$value
		return(c(grid_dfrm$S[i],grid_dfrm$D[i],X_new,W_new))
	}

	guess_time_int <- proc.time()[3] - guess_time
	message("Finished. Time taken to solve on small grid: ", round(guess_time_int,3), " seconds.")

	guess_time <- proc.time()[3]
	message("Scaling small grid up...")

	# foreach/rbind returns a matrix; name its columns and convert before growing the grid.
	colnames(guess_dfrm) <- c("S","D","X","fleet_value")
	guess_dfrm <- as.data.frame(guess_dfrm)
	rownames(guess_dfrm) <- NULL
	guess_dfrm_big <- grid_grow(guess_dfrm, grid_dfrm_big, policy_iteration, PFI_T, params)

	guess_time <- proc.time()[3] - guess_time
	message("Finished. Scaling time: ", round(guess_time,3) ," seconds.")

	### Make output plots
	VFI_plot_small <- ggplot(data = guess_dfrm, aes(x = D, y = S)) +
		geom_contour_filled(aes(z = fleet_value)) +
		labs(x = "Debris", y = "Satellites", fill = "Fleet value", title = "Guess VF small") + theme_bw()
	policy_plot_small <- ggplot(data = guess_dfrm, aes(x = D, y = S)) +
		geom_contour_filled(aes(z = X)) +
		labs(x = "Debris", y = "Satellites", fill = "Launch rate", title = "Guess PF small") + theme_bw()

	VFI_plot_big <- ggplot(data = guess_dfrm_big, aes(x = D, y = S)) +
		geom_contour_filled(aes(z = fleet_value)) +
		labs(x = "Debris", y = "Satellites", fill = "Fleet value", title = "Guess VF big") + theme_bw()
	policy_plot_big <- ggplot(data = guess_dfrm_big, aes(x = D, y = S)) +
		geom_contour_filled(aes(z = X)) +
		labs(x = "Debris", y = "Satellites", fill = "Launch rate", title = "Guess PF big") + theme_bw()

	summary_plot <- (VFI_plot_small | policy_plot_small) / (VFI_plot_big | policy_plot_big) + plot_annotation(tag_levels='a')

	# Render to png and always close the device, even if printing the plot fails.
	save_summary_png <- function(filename) {
		png(width=800,height=400,filename=filename)
		tryCatch(print(summary_plot), finally = dev.off())
	}

	if(fig4==TRUE) {
		save_summary_png(paste0("../../images/fig_4_",label,"_guess.png"))
	}

	if(fig4==FALSE){
		print(getwd())
		save_summary_png(paste0("../images/fig_6_",label,"_guess.png"))
	}

	return(guess_dfrm_big)
}

# Build a satellite/debris grid: S_n evenly spaced S values on [S_min, S_max] crossed with D_n evenly
# spaced D values on [D_min, D_max]. Uses expand.grid order (S varies fastest). Columns are S and D.
make_sd_grid <- function(S_min, S_max, S_n, D_min, D_max, D_n){
  grid <- expand.grid(
    seq(S_min, S_max, length.out = S_n),
    seq(D_min, D_max, length.out = D_n)
  )
  names(grid) <- c("S", "D")
  grid
}

# Build a guess on grid_dfrm_small (scaled up to grid_dfrm_big) with guess_builder, using
# T = params$T. Save it to ../data/VFIguess_<nrow(grid_dfrm_big)>_<label>.csv. Then run dp_solver from
# that guess, save the result to ../data/VFI_result_<label>_.csv, and return the solved data frame.
guess_and_solve <- function(params, grid_dfrm_small, grid_dfrm_big, grid_list, label, policy_iteration=0, PFI_T=10, steps_guarantee=NULL){
  message("Generating guess...")
  small_grid <- as.data.frame(grid_dfrm_small)
  big_grid <- as.data.frame(grid_dfrm_big)
  guess <- guess_builder(params, grid_dfrm=small_grid, grid_dfrm_big=big_grid, T=params$T, label=label, policy_iteration=policy_iteration, PFI_T=PFI_T)
  guess_file <- paste0("../data/VFIguess_",nrow(grid_dfrm_big),"_",label,".csv")
  fwrite(guess, guess_file)

  message("Solving DP problem...")
  solved <- dp_solver(params, guess, grid_list,steps_guarantee=steps_guarantee)
  result_file <- paste0("../data/VFI_result_",label,"_.csv")
  fwrite(solved, result_file)
  solved
}

# Wrapper to solve the planner's problem.
# Starts `ncores` workers (makeCluster + registerDoParallel). Builds a guess on small_guess_grid, runs
# VFI on larger_vfi_grid via guess_and_solve, and returns the solved policy/value data frame.
# The cluster is stopped when the function exits, including on error.
solve_opt_policy <- function(params, small_guess_grid, larger_vfi_grid, grid_list, ncores=60, label="default",steps_guarantee=NULL){
	solver_time <- proc.time()[3]
	message("Initializing workers...")
	vfi_cl <- makeCluster(ncores)
	on.exit(stopCluster(vfi_cl), add = TRUE)
	registerDoParallel(vfi_cl)

	# Use the grids passed in (previously this read global SD_grid_s / SD_grid instead).
	opt_policy <- guess_and_solve(params, small_guess_grid, larger_vfi_grid, grid_list, label, policy_iteration=0, PFI_T=10,steps_guarantee=steps_guarantee)

	solver_time <- proc.time()[3] - solver_time
	message("Time to solve planner problem(s): ",round(solver_time/60,3)," minutes.")

	return(opt_policy)
}

# Expand the planner's solved policy onto a larger [0, 10] x [0, 10] grid.
# The added states are assumed to lie in the planner's inaction region, so their launch rate X is
# filled with 0. Other solved columns (e.g. fleet_value) are NA there. This avoids solving the DP
# problem on states where the policy is known, while giving a larger grid for the basin calculation
# and phase diagrams.
#   S_element, D_element - the (evenly spaced, starting at 0) axes of the solved grid.
#   grid_list            - unused; kept for backward compatibility.
#   solved_dfrm          - solved policy with columns S, D, X, ...
# Returns the full join of the big grid with solved_dfrm, ordered with S varying fastest within
# each D.
expand_policy <- function(S_element, D_element, grid_list, solved_dfrm) {
	message("Expanding grid area...")

	# Stretch an axis of length M by a factor K = 10/max(axis) while keeping its spacing: this needs
	# K*M - (K-1) points (e.g. doubling a length-10 axis needs 19). seq() rounds a fractional
	# length.out up, so a non-integer K only approximately preserves the spacing.
	extend_axis <- function(axis) {
		M <- length(axis)
		K <- 10/max(axis)
		seq(0,10,length.out=(K*M - (K-1)))
	}

	SD_grid_big <- expand_grid(S = extend_axis(S_element), D = extend_axis(D_element))

	# expand_grid() varies D fastest. Callers interpolate X with ipol() over
	# list(S = unique(S), D = unique(D)), which expects S to vary fastest (the expand.grid order used by
	# make_sd_grid). Sort to match.
	expanded_solve <- full_join(SD_grid_big, solved_dfrm, by=c("S","D")) %>%
		arrange(D, S)
	expanded_solve$X[is.na(expanded_solve$X)] <- 0
	return(expanded_solve)
}

# Wrapper to compute the planner's and the open-access stable basins of attraction.
# Steps:
#   1. Expand optimal_policy (columns S, D, X, ...) to a larger grid with expand_policy.
#   2. Build a multilinear interpolant of the planner's launch rate.
#   3. Run gen_fig10_data for the planner (opt = 1) and for open access (opt = 0), using `ncores`
#      workers each. The default of 80 matches the previously hard-coded value.
#   4. Write the results to ../data/opt_<label>_basin.csv and ../data/oa_<label>_basin.csv.
# Returns list(opt = planner basin data, oa = open-access basin data).
find_stable_basins <- function(params, optimal_policy, SD_grid, grid_list, ncores=80, label="benchmark"){

	message("Expanding planner's policy to larger grid...")
	# Expand planner's policy to larger grid over inaction region
	optimal_policy <- expand_policy(S_element=unique(SD_grid$S), D_element=unique(SD_grid$D), grid_list=grid_list, solved_dfrm=optimal_policy)

	message("Generating interpolants for optimal policy...")
	grid_list_big <- list(S = unique(optimal_policy$S), D = unique(optimal_policy$D))
	opt_launch_rate_ipol <- ipol(optimal_policy$X, grid=grid_list_big, method="multilinear")

	SD_grid_convergence_input <- optimal_policy
	colnames(SD_grid_convergence_input)[1:2] <- c("satellites","debris")

	message("Calculating optimal steady state basin of attraction...")
	basin_time <- proc.time()[3]
	opt_basin_dfrm <- gen_fig10_data(params, SD_grid_convergence_input, opt=1, opt_interpolant=opt_launch_rate_ipol, ncores=ncores, opt_policy_grid=optimal_policy)
	basin_time <- proc.time()[3] - basin_time
	message("Time to compute basin: ",round(basin_time/60,3)," minutes.")
	fwrite(opt_basin_dfrm, paste0("../data/opt_",label,"_basin.csv"))

	# Replace the planner's X with the open-access launch rate, so the open-access run does not reuse the optimal policy.
	oa_SD_grid_convergence_input <- SD_grid_convergence_input %>%
		select(-X) %>%
		mutate(X = eqm_launch_rate(satellites,debris,params))

	message("Calculating open access stable steady state basin of attraction...")
	basin_time <- proc.time()[3]
	oa_basin_dfrm <- gen_fig10_data(params, oa_SD_grid_convergence_input, opt=0, ncores=ncores)
	basin_time <- proc.time()[3] - basin_time
	message("Time to compute basin: ",round(basin_time/60,3)," minutes.")

	fwrite(oa_basin_dfrm, paste0("../data/oa_",label,"_basin.csv"))

	return(list(opt=opt_basin_dfrm, oa=oa_basin_dfrm))
}


# Build policy maps and phase diagrams for the planner and open-access policies.
#   oa_dfrm, opt_dfrm - basin data (e.g. from find_stable_basins). They are joined on satellites/debris,
#                       and their columns get the suffixes .oa / .opt.
#   s_scale_factor, d_scale_factor - optional multipliers for the satellite / debris axes.
#   params            - model parameter list for eqm_launch_rate, S_ and D_. When NULL, `params` is
#                       looked up from where this function was defined, as before this argument existed.
# Returns list(opt_policy, oa_policy, opt_phase, oa_phase) of ggplot objects.
# NOTE(review): the row filters such as oa_opt_big[X.opt>0|path_convergence.opt>0,] use data.table
# syntax, so oa_dfrm is presumably a data.table (e.g. read with fread) — confirm with callers.
make_policy_phase_plots <- function(oa_dfrm, opt_dfrm, s_scale_factor=NULL, d_scale_factor=NULL, params=NULL){
	# Backward compatibility: the function used to read a free `params` variable lexically.
	if(is.null(params)) params <- get("params", envir=parent.env(environment()))

	oa_opt_big <- left_join(oa_dfrm, opt_dfrm, by=c("satellites","debris"), suffix=c(".oa",".opt"))

	# Phase-portrait variables ((S' - S) / h, with h = 10) for both policies; X.oa is recomputed from
	# the open-access launch rate.
	oa_opt_big <- oa_opt_big %>%
		mutate(
			X.oa = eqm_launch_rate(satellites,debris,params)) %>%
		mutate(
			opt_S_next = S_(X.opt,satellites,debris,params),
			opt_D_next = D_(X.opt,satellites,debris,params),
			opt_dS = (opt_S_next - satellites)/10,
			opt_dD = (opt_D_next - debris)/10,
			oa_S_next = S_(X.oa,satellites,debris,params),
			oa_D_next = D_(X.oa,satellites,debris,params),
			oa_dS = (oa_S_next - satellites)/10,
			oa_dD = (oa_D_next - debris)/10
		)

	# Scale the axes if desired
	if(!is.null(s_scale_factor)){
		oa_opt_big$satellites <- oa_opt_big$satellites * s_scale_factor
		oa_opt_big$opt_S_next <- oa_opt_big$opt_S_next * s_scale_factor
		oa_opt_big$oa_S_next <- oa_opt_big$oa_S_next * s_scale_factor
		oa_opt_big$oa_dS <- oa_opt_big$oa_dS * s_scale_factor
		oa_opt_big$opt_dS <- oa_opt_big$opt_dS * s_scale_factor
		oa_opt_big$X.oa <- oa_opt_big$X.oa * s_scale_factor
		oa_opt_big$X.opt <- oa_opt_big$X.opt * s_scale_factor
	}

	if(!is.null(d_scale_factor)){
		oa_opt_big$debris <- oa_opt_big$debris * d_scale_factor
		oa_opt_big$opt_D_next <- oa_opt_big$opt_D_next * d_scale_factor
		oa_opt_big$oa_D_next <- oa_opt_big$oa_D_next * d_scale_factor
		oa_opt_big$oa_dD <- oa_opt_big$oa_dD * d_scale_factor
		oa_opt_big$opt_dD <- oa_opt_big$opt_dD * d_scale_factor
	}

	# brewer.pal(n = 11, name = "RdYlBu")
	# The policy maps share the fill limits of the open-access launch rate. Their invisible
	# (alpha = 0) contours are kept as-is.
	fig_opt_policy <- ggplot(data=oa_opt_big, aes(x=debris,y=satellites)) +
		geom_tile(data=oa_opt_big[X.opt>0|path_convergence.opt>0,], aes(x=debris,y=satellites,fill=X.opt, alpha=(X.opt>0))) +
		labs(x="Debris", y="Satellites", title="Planner's policy", fill="Launch rate") +
		scale_x_continuous(labels=comma) +
		scale_y_continuous(labels=comma) +
		theme_bw() +
		stat_contour(aes(z=opt_dS, group=after_stat(level)), linewidth=1, breaks=c(0), color="#313695", alpha=0) +
		stat_contour(aes(z=opt_dD, group=after_stat(level)), linewidth=1, breaks=c(0), color="#A50026", alpha=0) +
		scale_fill_distiller(palette="Greys", direction=1, limits = range(oa_opt_big$X.oa)) +
		theme(
			text = element_text(size=30, family="Arial"),
			# axis.text.x = element_blank(),
			legend.key.size = unit(3, 'lines'),
			legend.position = "bottom"
		) + guides(alpha="none", fill="none") + coord_flip()

	fig_oa_policy <- ggplot(data=oa_opt_big, aes(x=debris,y=satellites)) +
		geom_tile(data=oa_opt_big[X.oa>0|path_convergence.oa>0,], aes(x=debris,y=satellites,fill=X.oa, alpha=(X.oa>0))) +
		labs(x="Debris", y="Satellites", title="Open-access policy", fill="Launch rate") +
		stat_contour(aes(z=opt_dS, group=after_stat(level)), linewidth=1, breaks=c(0), color="#313695", alpha=0) +
		stat_contour(aes(z=opt_dD, group=after_stat(level)), linewidth=1, breaks=c(0), color="#A50026", alpha=0) +
		scale_x_continuous(labels=comma) +
		scale_y_continuous(labels=comma) +
		theme_bw() +
		scale_fill_distiller(palette="Greys", direction=1, limits = range(oa_opt_big$X.oa)) +
		theme(
			text = element_text(size=30, family="Arial"),
			# axis.text.x = element_blank(),
			legend.key.size = unit(3, 'lines'),
			legend.position = "bottom"
		) + coord_flip() + guides(alpha="none", fill="none")

	# brewer.pal(n = 8, name = "Set2")
	fig_opt_phase <- ggplot(data=oa_opt_big, aes(x=debris,y=satellites)) +
		geom_point(data=oa_opt_big[path_convergence.opt>0,], aes(x=debris,y=satellites), color="#B3B3B3") +
		stat_contour(aes(z=opt_dS, group=after_stat(level)), linewidth=1, breaks=c(0), color="#313695") +
		stat_contour(aes(z=opt_dD, group=after_stat(level)), linewidth=1, breaks=c(0), color="#A50026") +
		labs(x="Debris", y="Satellites", title="Planner's phase diagram & stable basin") +
		scale_x_continuous(labels=comma) +
		scale_y_continuous(labels=comma) +
		theme_bw() +
		theme(
			text = element_text(size=30, family="Arial"),
			# axis.text.x = element_blank(),
			legend.key.size = unit(3, 'lines'),
			legend.position = "bottom"
		) + coord_flip()

	fig_oa_phase <- ggplot(data=oa_opt_big, aes(x=debris,y=satellites)) +
		geom_point(data=oa_opt_big[path_convergence.oa>0,], aes(x=debris,y=satellites), color="#B3B3B3") +
		stat_contour(aes(z=oa_dS, group=after_stat(level)), linewidth=1, breaks=c(0), color="#313695") +
		stat_contour(aes(z=oa_dD, group=after_stat(level)), linewidth=1, breaks=c(0), color="#A50026") +
		geom_text(label = "Dn",	x = 6.35, y = 0.75, color="#A50026", size=7) +
		geom_text(label = "Sn",	x = 7.65, y = 0.3, color="#313695", size=7) +
		labs(x="Debris", y="Satellites", title="Open-access phase diagram & stable basin") +
		scale_x_continuous(labels=comma) +
		scale_y_continuous(labels=comma) +
		theme_bw() +
		theme(
			text = element_text(size=30, family="Arial"),
			# axis.text.x = element_blank(),
			legend.key.size = unit(3, 'lines'),
			legend.position = "bottom"
		) + coord_flip() #+
		# # horizontal arrow from (0.1,0.1)
		# geom_segment(	x=0.1,
		# 				y=0.1,
		# 				xend=(0.1+0.15),
		# 				yend=(0.1),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +
		# # vertical arrow from (0.1,0.1)
		# geom_segment(	x=0.1,
		# 				y=0.1,
		# 				xend=(0.1),
		# 				yend=(0.1+0.15),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +

		# # horizontal arrow from (3.5,0.5)
		# geom_segment(	x=2,
		# 				y=0.1,
		# 				xend=(2-0.15),
		# 				yend=(0.1),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +
		# # vertical arrow from (3.5,0.5)
		# geom_segment(	x=2,
		# 				y=0.6,
		# 				xend=(2),
		# 				yend=(0.1+0.15),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +

		# # horizontal arrow from (3.5,3.75)
		# geom_segment(	x=1.15,
		# 				y=3.75,
		# 				xend=(1.15-0.15),
		# 				yend=(3.75),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +
		# # vertical arrow from (3.5,3.75)
		# geom_segment(	x=1.15,
		# 				y=3.75,
		# 				xend=(1.15),
		# 				yend=(1.15-0.15),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +

		# # horizontal arrow from (0.5,3.75)
		# geom_segment(	x=0.5,
		# 				y=3.75,
		# 				xend=(0.5+0.15),
		# 				yend=(3.75),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	) +
		# # vertical arrow from (0.5,3.75)
		# geom_segment(	x=0.5,
		# 				y=3.75,
		# 				xend=(0.5),
		# 				yend=(3.75-0.15),
		# 				arrow=arrow(length= unit(0.25,"cm")), linewidth=1	)

	return(list(opt_policy=fig_opt_policy, oa_policy=fig_oa_policy, opt_phase=fig_opt_phase, oa_phase=fig_oa_phase))

}
